Conversation
Please try rebasing. #1403 is the original issue I mentioned.
tianyu-l
left a comment
Left some questions.
Could also address #1365 (comment):
- rebase and see if the non-persistent buffer `tokens_per_expert` is causing trouble
- manually try changing `freqs_cis` to non-persistent and see if the issue is still there: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/model/model.py#L388
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    torch._dynamo.config.fail_on_recompile_limit_hit = True
What is this for?
Other than this, it seems we can just apply the same function Llama 3 uses.
This is to raise a loud error if we recompile more than 8 times (the default limit). Currently, we would just silently fall back to eager if that happens.
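The behavior being described can be sketched in plain Python (this is a toy model of the flag's effect, not dynamo internals; `Compiler` and `RECOMPILE_LIMIT` are illustrative names):

```python
# Toy model of fail_on_recompile_limit_hit: after hitting the recompile
# limit, either silently fall back to "eager" or raise a loud error.
RECOMPILE_LIMIT = 8
fail_on_recompile_limit_hit = True  # mirrors the torch._dynamo.config flag

class Compiler:
    def __init__(self):
        self.recompiles = 0

    def __call__(self, shape):
        # In this toy, every new input shape forces a recompile.
        self.recompiles += 1
        if self.recompiles > RECOMPILE_LIMIT:
            if fail_on_recompile_limit_hit:
                raise RuntimeError("recompile limit hit")
            return "eager"  # silent fallback: easy to miss in training runs
        return f"compiled(shape={shape})"

c = Compiler()
for s in range(8):
    c(s)  # 8 recompiles: still within the limit

try:
    c(99)  # 9th recompile exceeds the limit
    outcome = "silent eager fallback"
except RuntimeError:
    outcome = "loud error"
print(outcome)  # loud error
```

With the flag off, the 9th call would quietly return the eager path, which is exactly the silent performance regression the setting is meant to surface.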
Should we do the same for Llama 3? If so, we can still reuse this function.
            self.w1, self.w2, self.w3, x, num_tokens_per_expert
        )

    # TODO: keeping this for-loop implementation for comparison
staticmethods on user-defined classes cannot be generically supported, so I moved those out.

Could you explain more? Does it mean that if we move them out, torch.compile can trace them in the same graph the caller module is in?
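The refactor under discussion can be sketched in plain Python (class and function names here are illustrative, not the PR's actual code): the helper moves from a `@staticmethod` on a user-defined class to a module-level function, which a tracer sees as an ordinary function call rather than an attribute lookup on the class.

```python
# Before: helper lives on the class as a staticmethod.
class GroupedExpertsBefore:
    @staticmethod
    def run_experts(weights, x):
        return [w * x for w in weights]

# After: the same helper as a plain module-level function; the module's
# forward just calls it. Eager behavior is identical either way.
def run_experts(weights, x):
    return [w * x for w in weights]

class GroupedExpertsAfter:
    def __init__(self, weights):
        self.weights = weights

    def forward(self, x):
        return run_experts(self.weights, x)

before = GroupedExpertsBefore.run_experts([1, 2, 3], 10)
after = GroupedExpertsAfter([1, 2, 3]).forward(10)
print(before == after)  # True: the refactor preserves behavior
```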
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.use_grouped_mm = use_grouped_mm

    @torch._dynamo.set_fullgraph(True)
What is this annotation for?
Compiling the block with fullgraph=False could allow graph breaks to creep in silently with dynamo changes, and we wouldn't know about them until we manually inspect the graph or suspect QPS has regressed.
This API gives more granular control over the fullgraph argument of torch.compile: you can flip it on and off within a compiled region. In this case, we allow graph breaks between GroupedExperts.__call__ and GroupedExperts.forward, i.e. we allow graph breaks on the forward hooks from FSDP.
In addition to FSDP comms, EP a2a also happens before & after GroupedExperts.forward. Does it mean it's still not fine-grained enough to capture graphs in EP?
        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_tokens_per_expert)
        with torch._dynamo.set_fullgraph(False):
IIUC, this annotation is for the FSDP-caused graph break, correct?
Can we possibly put this in the apply_compile function instead? Technically this change is model-intrusive, despite being small.
This API can't decorate GroupedExperts.__call__ right now. If it's a problem, we can just compile MoE with fullgraph=False.
Status

We don't have a good way in compile to specify fullgraph=True except for the FSDP hooks at the moment. We can either leave it fullgraph=False or just wrap the experts model code in set_fullgraph(False)/set_fullgraph(True).

Repro

Tested on the debug model:
NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --training.compile

logs: https://gist.github.com/xmfan/41b822d9f09eb07fee62d684a061cec1
memory: 2.20GiB -> 1.42GiB
speedup: no big change, need to check with an actual model